{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "!pip install -Uqq fastbook\n",
    "import fastbook\n",
    "fastbook.setup_book()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "from fastbook import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# ResNets"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Going Back to Imagenette"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_data(url, presize, resize):\n",
    "    path = untar_data(url)\n",
    "    return DataBlock(\n",
    "        blocks=(ImageBlock, CategoryBlock), get_items=get_image_files, \n",
    "        splitter=GrandparentSplitter(valid_name='val'),\n",
    "        get_y=parent_label, item_tfms=Resize(presize),\n",
    "        batch_tfms=[*aug_transforms(min_scale=0.5, size=resize),\n",
    "                    Normalize.from_stats(*imagenet_stats)],\n",
    "    ).dataloaders(path, bs=128)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dls = get_data(URLs.IMAGENETTE_160, 160, 128)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dls.show_batch(max_n=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def avg_pool(x): return x.mean((2,3))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def block(ni, nf): return ConvLayer(ni, nf, stride=2)\n",
    "def get_model():\n",
    "    return nn.Sequential(\n",
    "        block(3, 16),\n",
    "        block(16, 32),\n",
    "        block(32, 64),\n",
    "        block(64, 128),\n",
    "        block(128, 256),\n",
    "        nn.AdaptiveAvgPool2d(1),\n",
    "        Flatten(),\n",
    "        nn.Linear(256, dls.c))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_learner(m):\n",
    "    return Learner(dls, m, loss_func=nn.CrossEntropyLoss(), metrics=accuracy\n",
    "                  ).to_fp16()\n",
    "\n",
    "learn = get_learner(get_model())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.lr_find()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.fit_one_cycle(5, 3e-3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Building a Modern CNN: ResNet"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Skip Connections"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ResBlock(Module):\n",
    "    def __init__(self, ni, nf):\n",
    "        self.convs = nn.Sequential(\n",
    "            ConvLayer(ni,nf),\n",
    "            ConvLayer(nf,nf, norm_type=NormType.BatchZero))\n",
    "        \n",
    "    def forward(self, x): return x + self.convs(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _conv_block(ni,nf,stride):\n",
    "    return nn.Sequential(\n",
    "        ConvLayer(ni, nf, stride=stride),\n",
    "        ConvLayer(nf, nf, act_cls=None, norm_type=NormType.BatchZero))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ResBlock(Module):\n",
    "    def __init__(self, ni, nf, stride=1):\n",
    "        self.convs = _conv_block(ni,nf,stride)\n",
    "        self.idconv = noop if ni==nf else ConvLayer(ni, nf, 1, act_cls=None)\n",
    "        self.pool = noop if stride==1 else nn.AvgPool2d(2, ceil_mode=True)\n",
    "\n",
    "    def forward(self, x):\n",
    "        return F.relu(self.convs(x) + self.idconv(self.pool(x)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def block(ni,nf): return ResBlock(ni, nf, stride=2)\n",
    "learn = get_learner(get_model())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.fit_one_cycle(5, 3e-3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def block(ni, nf):\n",
    "    return nn.Sequential(ResBlock(ni, nf, stride=2), ResBlock(nf, nf))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn = get_learner(get_model())\n",
    "learn.fit_one_cycle(5, 3e-3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### A State-of-the-Art ResNet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _resnet_stem(*sizes):\n",
    "    return [\n",
    "        ConvLayer(sizes[i], sizes[i+1], 3, stride = 2 if i==0 else 1)\n",
    "            for i in range(len(sizes)-1)\n",
    "    ] + [nn.MaxPool2d(kernel_size=3, stride=2, padding=1)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "_resnet_stem(3,32,32,64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ResNet(nn.Sequential):\n",
    "    def __init__(self, n_out, layers, expansion=1):\n",
    "        stem = _resnet_stem(3,32,32,64)\n",
    "        self.block_szs = [64, 64, 128, 256, 512]\n",
    "        for i in range(1,5): self.block_szs[i] *= expansion\n",
    "        blocks = [self._make_layer(*o) for o in enumerate(layers)]\n",
    "        super().__init__(*stem, *blocks,\n",
    "                         nn.AdaptiveAvgPool2d(1), Flatten(),\n",
    "                         nn.Linear(self.block_szs[-1], n_out))\n",
    "    \n",
    "    def _make_layer(self, idx, n_layers):\n",
    "        stride = 1 if idx==0 else 2\n",
    "        ch_in,ch_out = self.block_szs[idx:idx+2]\n",
    "        return nn.Sequential(*[\n",
    "            ResBlock(ch_in if i==0 else ch_out, ch_out, stride if i==0 else 1)\n",
    "            for i in range(n_layers)\n",
    "        ])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rn = ResNet(dls.c, [2,2,2,2])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn = get_learner(rn)\n",
    "learn.fit_one_cycle(5, 3e-3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Bottleneck Layers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _conv_block(ni,nf,stride):\n",
    "    return nn.Sequential(\n",
    "        ConvLayer(ni, nf//4, 1),\n",
    "        ConvLayer(nf//4, nf//4, stride=stride), \n",
    "        ConvLayer(nf//4, nf, 1, act_cls=None, norm_type=NormType.BatchZero))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dls = get_data(URLs.IMAGENETTE_320, presize=320, resize=224)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rn = ResNet(dls.c, [3,4,6,3], 4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn = get_learner(rn)\n",
    "learn.fit_one_cycle(20, 3e-3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Conclusion"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Questionnaire"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "1. How did we get to a single vector of activations in the CNNs used for MNIST in previous chapters? Why isn't that suitable for Imagenette?\n",
    "1. What do we do for Imagenette instead?\n",
    "1. What is \"adaptive pooling\"?\n",
    "1. What is \"average pooling\"?\n",
    "1. Why do we need `Flatten` after an adaptive average pooling layer?\n",
    "1. What is a \"skip connection\"?\n",
    "1. Why do skip connections allow us to train deeper models?\n",
    "1. What does <<resnet_depth>> show? How did that lead to the idea of skip connections?\n",
    "1. What is \"identity mapping\"?\n",
    "1. What is the basic equation for a ResNet block (ignoring batchnorm and ReLU layers)?\n",
    "1. What do ResNets have to do with residuals?\n",
    "1. How do we deal with the skip connection when there is a stride-2 convolution? How about when the number of filters changes?\n",
    "1. How can we express a 1×1 convolution in terms of a vector dot product?\n",
    "1. Create a `1x1 convolution` with `F.conv2d` or `nn.Conv2d` and apply it to an image. What happens to the `shape` of the image?\n",
    "1. What does the `noop` function return?\n",
    "1. Explain what is shown in <<resnet_surface>>.\n",
    "1. When is top-5 accuracy a better metric than top-1 accuracy?\n",
    "1. What is the \"stem\" of a CNN?\n",
    "1. Why do we use plain convolutions in the CNN stem, instead of ResNet blocks?\n",
    "1. How does a bottleneck block differ from a plain ResNet block?\n",
    "1. Why is a bottleneck block faster?\n",
    "1. How do fully convolutional nets (and nets with adaptive pooling in general) allow for progressive resizing?"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Further Research"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "1. Try creating a fully convolutional net with adaptive average pooling for MNIST (note that you'll need fewer stride-2 layers). How does it compare to a network without such a pooling layer?\n",
    "1. In <<chapter_foundations>> we introduce *Einstein summation notation*. Skip ahead to see how this works, and then write an implementation of the 1×1 convolution operation using `torch.einsum`. Compare it to the same operation using `torch.conv2d`.\n",
    "1. Write a \"top-5 accuracy\" function using plain PyTorch or plain Python.\n",
    "1. Train a model on Imagenette for more epochs, with and without label smoothing. Take a look at the Imagenette leaderboards and see how close you can get to the best results shown. Read the linked pages describing the leading approaches."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "jupytext": {
   "split_at_heading": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
